package main

import (
	"bytes"
	"crypto/sha256"
	"errors"

	//"fmt"
	"sync"
)

// MerkleTree is a SHA-256 Merkle tree whose fan-out can differ per layer.
// Layer 0 holds the single root node; the last layer holds the leaf hashes.
type MerkleTree struct {
	level int        // number of layers, equal to len(fork)
	fork  []int      // fork[i]: children per layer-(i-1) node that live in layer i; fork[0] (root layer) is expected to be 1
	ln    []int      // ln[i]: number of nodes in layer i; ln[0] = 1, ln[i] = ln[i-1]*fork[i]
	node  [][]Sha256 // node[i][j]: hash of the j-th node of layer i; node[0][0] is the root
}

// MerkleProof is produced by MerkleTree.GenerateMerkleProof and checked by
// MerkleProof.Check.
type MerkleProof struct {
	// Fork holds the per-layer fork counts of the tree the proof came from.
	Fork  []int
	// Proof[0] is the complete group of sibling leaf hashes containing the
	// proven leaf; each following entry is the complete sibling group one
	// layer closer to the root, and the last entry is the root layer.
	Proof [][]Sha256
}

// Sha256 is the byte-slice type used for every hash stored in the tree and in
// proofs; the values produced in this file are crypto/sha256 digests.
type Sha256 = []byte

// InitConfig configures the shape of the tree.
//
// fork[i] is the number of children that each node of layer i-1 has in
// layer i. fork[0] describes the root layer and must be 1. Layer i then holds
// fork[1]*...*fork[i] nodes. The data later passed to GenerateTree must split
// evenly across the last (leaf) layer.
//
// It returns an error and leaves m unchanged in three cases: fork is empty,
// fork[0] != 1, or any fork count is not positive. Such configurations would
// otherwise panic here or in GenerateTree/GenerateMerkleProof. fork is copied,
// so later changes to the caller's slice do not affect the tree.
func (m *MerkleTree) InitConfig(fork []int) error {
	if len(fork) == 0 {
		return errors.New("分叉数不能为空")
	}
	if fork[0] != 1 {
		return errors.New("第一层分叉数必须为1")
	}
	for _, f := range fork[1:] {
		if f <= 0 {
			return errors.New("分叉数必须为正数")
		}
	}

	m.fork = append([]int(nil), fork...)
	m.level = len(m.fork)
	m.ln = make([]int, m.level)
	m.ln[0] = 1
	for i := 1; i < m.level; i++ {
		m.ln[i] = m.ln[i-1] * m.fork[i]
	}
	return nil
}

// Init applies the default configuration: four layers with fork counts
// 1, 4, 16, 16, which gives 1, 4, 64 and 1024 nodes per layer.
func (m *MerkleTree) Init() {
	m.fork = []int{1, 4, 16, 16}
	m.ln = []int{1, 4, 64, 1024}
	m.level = len(m.fork)
}

// GenerateTree builds the tree over Data.
//
// Data is split into m.ln[m.level-1] equal-sized chunks, and the SHA-256 of
// each chunk forms the leaf layer. Node j of layer i-1 is the SHA-256 of the
// concatenation of its children m.node[i][j*m.fork[i] : (j+1)*m.fork[i]].
//
// The tree must be configured with Init or InitConfig first. An error is
// returned, and m.node is left untouched, in two cases: the tree is not
// configured, or len(Data) is not a multiple of the leaf count (callers must
// pad the data).
func (m *MerkleTree) GenerateTree(Data []byte) error {
	// A zero-value tree has level 0; indexing m.ln[m.level-1] would panic.
	if m.level == 0 || len(m.ln) != m.level || len(m.fork) != m.level || m.ln[m.level-1] <= 0 {
		return errors.New("Merkle树未初始化")
	}
	leaves := m.ln[m.level-1]
	if len(Data)%leaves != 0 {
		return errors.New("请补足数据")
	}
	unit := len(Data) / leaves

	// hashAll evaluates f(0), ..., f(n-1) concurrently, one goroutine per
	// index. Each goroutine writes only its own slot of out, so no lock is
	// needed; wg.Wait makes all writes visible before out is returned.
	hashAll := func(n int, f func(int) Sha256) []Sha256 {
		out := make([]Sha256, n)
		var wg sync.WaitGroup
		wg.Add(n)
		for idx := 0; idx < n; idx++ {
			go func(idx int) {
				defer wg.Done()
				out[idx] = f(idx)
			}(idx)
		}
		wg.Wait()
		return out
	}

	nodes := make([][]Sha256, m.level)
	nodes[m.level-1] = hashAll(leaves, func(i int) Sha256 {
		sum := sha256.Sum256(Data[i*unit : (i+1)*unit])
		return sum[:]
	})

	// Build each upper layer from the layer just below it, bottom-up.
	for i := m.level - 1; i > 0; i-- {
		children, fork := nodes[i], m.fork[i]
		nodes[i-1] = hashAll(m.ln[i-1], func(j int) Sha256 {
			h := sha256.New()
			for _, c := range children[j*fork : (j+1)*fork] {
				h.Write(c)
			}
			return h.Sum(nil)
		})
	}

	m.node = nodes
	return nil
}

// Serial flattens the tree into one byte slice. The hashes of layer 0 (the
// root) come first, followed by each deeper layer, left to right within a
// layer. It returns nil if the tree has not been generated.
func (m *MerkleTree) Serial() (b []byte) {
	for layer := 0; layer < len(m.node); layer++ {
		for _, hash := range m.node[layer] {
			b = append(b, hash...)
		}
	}
	return b
}

// Root returns the root hash of the tree. It returns nil instead of panicking
// when GenerateTree has not completed successfully yet. The returned slice
// is the tree's own storage and must not be modified.
func (m *MerkleTree) Root() (r []byte) {
	if len(m.node) == 0 || len(m.node[0]) == 0 {
		return nil
	}
	return m.node[0][0]
}

// GenerateMerkleProof builds the proof for the leaf at index location, where
// 0 <= location < number of leaves.
//
// Proof[0] is the complete group of sibling leaf hashes that contains the
// leaf. Each following entry is the complete sibling group one layer up,
// ending with the root layer. Fork and the Proof groups alias the tree's own
// slices and must not be modified.
//
// It returns an error if the tree has not been generated or location is out
// of range.
func (m *MerkleTree) GenerateMerkleProof(location int) (proof MerkleProof, err error) {
	if m.level == 0 || len(m.node) != m.level || len(m.fork) != m.level {
		return MerkleProof{}, errors.New("Merkle树尚未生成")
	}
	// Valid leaf indices are [0, leafCount); location == leafCount would
	// slice past the end of the leaf layer.
	if location < 0 || location >= len(m.node[m.level-1]) {
		return MerkleProof{}, errors.New("位置越界")
	}

	proof.Fork = m.fork
	proof.Proof = make([][]Sha256, 0, m.level)
	for i := m.level - 1; i >= 0; i-- {
		// After the division, location is the index of the parent node, i.e.
		// the index of the fork[i]-sized sibling group within layer i.
		location /= m.fork[i]
		group := m.node[i][location*m.fork[i] : (location+1)*m.fork[i]]
		proof.Proof = append(proof.Proof, group)
	}
	return proof, nil
}

// Check reports whether p is a well-formed, internally consistent proof for
// the leaf at index location.
//
// For every layer except the root layer, the SHA-256 of the concatenated
// group p.Proof[i] must equal the entry of p.Proof[i+1] at the position of
// that group's parent.
//
// These inputs are rejected (false) rather than causing a panic:
//   - the layer count does not match Fork
//   - a group's size differs from its fork count
//   - a fork count is not positive
//   - location is out of range
//
// Check sees neither the leaf data nor a trusted root. Callers must also
// compare the leaf hash with p.Proof[0][location%p.Fork[len(p.Fork)-1]], and
// p.Proof[len(p.Proof)-1][0] with the root they trust.
func (p *MerkleProof) Check(location int) bool {
	n := len(p.Fork)
	if n == 0 || len(p.Proof) != n {
		return false
	}
	leaves := 1
	for i, f := range p.Fork {
		// Proof is ordered leaf layer first, so Fork[i] sizes Proof[n-1-i].
		if f <= 0 || len(p.Proof[n-1-i]) != f {
			return false
		}
		leaves *= f
	}
	if location < 0 || location >= leaves {
		return false
	}

	for i := 0; i < n-1; i++ {
		h := sha256.New()
		for _, v := range p.Proof[i] {
			h.Write(v)
		}
		// parent is the index, within its own layer, of the node whose
		// children are p.Proof[i]. Taking it modulo the next group size
		// gives its slot in p.Proof[i+1].
		parent := location / p.Fork[n-1-i]
		if !bytes.Equal(h.Sum(nil), p.Proof[i+1][parent%p.Fork[n-2-i]]) {
			return false
		}
		location = parent
	}
	return true
}
